# Copyright (c) Alibaba, Inc. and its affiliates.
import inspect
from collections import defaultdict
from contextlib import contextmanager
from functools import partial
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import PreTrainedModel
from trl.models.utils import prepare_deepspeed
from trl.trainer.utils import selective_log_softmax


class RLHFTrainerMixin:
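    """Mixin shared by swift's TRL-based RLHF trainers (e.g. DPO/ORPO).

    It prepares the frozen reference model, adapts inputs for sequence
    parallelism and patches TRL's `concatenated_forward` so the parent loss
    computation can reuse a single forward pass built from the swift template.
    """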

    def __init__(self,
                 model: Optional[Union[PreTrainedModel, nn.Module]] = None,
                 ref_model: Optional[Union[PreTrainedModel, nn.Module]] = None,
                 *_args,
                 **kwargs):
        from trl.trainer import disable_dropout_in_model
        from swift.llm import HfConfigFactory
        self.ref_model = ref_model
        self._stored_metrics = defaultdict(lambda: defaultdict(list))
        args = kwargs['args']
        self.beta = getattr(args, 'beta', 0.0)
        if getattr(args, 'disable_dropout', False):
            disable_dropout_in_model(model)
            if self.ref_model is not None:
                disable_dropout_in_model(self.ref_model)

        self.is_encoder_decoder = kwargs['template'].is_encoder_decoder
        self._peft_has_been_casted_to_bf16 = False
        self.generate_during_eval = getattr(args, 'generate_during_eval', False)
        if self.is_encoder_decoder:
            self.decoder_start_token_id = HfConfigFactory.get_config_attr(model.config, 'decoder_start_token_id')
            self.pad_token_id = HfConfigFactory.get_config_attr(model.config, 'pad_token_id')
        # not used; kept for compatibility with TRL's trainers
        self.is_vision_model = False
        self.label_pad_token_id = -100
        self.use_dpo_data_collator = True
        super().__init__(model, *_args, **kwargs)
        self.aux_loss_enabled = model.model_info.is_moe_model and args.router_aux_loss_coef > 0
        self.aux_loss_coef = args.router_aux_loss_coef
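        # Wrap the frozen reference model for the active distributed backend:
        # DeepSpeed when enabled, otherwise accelerate in evaluation mode.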
        if ref_model is not None:
            if self.is_deepspeed_enabled:
                self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
            else:
                self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)

        self.padding_value = self.tokenizer.pad_token_id

    def create_loss_and_metric(self, args):
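        """Hook for subclasses to register extra losses/metrics; no-op by default."""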
        return {}

    def _prepare_inputs(self, inputs):
        inputs = super()._prepare_inputs(inputs)
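        # With sequence parallelism, split the prepared inputs along the sequence
        # dimension across the sequence-parallel ranks before the forward pass.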
        if self.template.sequence_parallel_size > 1:
            from swift.trainers.sequence_parallel import sequence_parallel
            sequence_parallel.prepare_inputs(inputs)
        return inputs

    def get_train_dataloader(self, *args, **kwargs):
        train_dataloader = super().get_train_dataloader(*args, **kwargs)
        base_dataloader = train_dataloader
        if hasattr(train_dataloader, 'base_dataloader') and isinstance(train_dataloader.base_dataloader, DataLoader):
            base_dataloader = train_dataloader.base_dataloader
        if base_dataloader.worker_init_fn is not None and not isinstance(
                base_dataloader.worker_init_fn, partial) and 'num_workers' in inspect.signature(
                    base_dataloader.worker_init_fn).parameters:
            base_dataloader.worker_init_fn = partial(
                base_dataloader.worker_init_fn,
                num_workers=self.args.dataloader_num_workers,
                rank=self.args.process_index)
        return train_dataloader

    def concatenated_forward(
        self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
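        """Single forward pass over the concatenated chosen/rejected batch.

        The model is run once on the template-built batch; the parent TRL
        implementation is then invoked with its input building and model call
        patched to return these precomputed kwargs and outputs.
        """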
        model_kwargs = batch.copy()
        labels = model_kwargs.pop('labels', None)
        if self.is_encoder_decoder:
            model_kwargs['labels'] = labels

        if self.aux_loss_enabled:
            model_kwargs['output_router_logits'] = True
        outputs = model(**model_kwargs, use_cache=False)
        model_kwargs['labels'] = labels
        # Dummy tensor: the parent trainer only uses its length (batch // 2) to split chosen/rejected.
        model_kwargs['chosen_labels'] = torch.zeros(model_kwargs['labels'].shape[0] // 2)
        if outputs.logits.shape[1] != labels.shape[1]:
            # for llava, the model returns logits for the entire sequence, including the image tokens
            # (placed before the text tokens)
            outputs.logits = outputs.logits[:, -labels.shape[1]:]
        for key in ['input_ids', 'attention_mask', 'labels']:
            model_kwargs[f'concatenated_{key}'] = model_kwargs.pop(key, None)
        if self.__class__.__name__ == 'ORPOTrainer':  # ORPO derives labels from input_ids; pass ours through
            model_kwargs['concatenated_input_ids'] = model_kwargs['concatenated_labels']

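        # The parent `concatenated_forward` calls `self.concatenated_inputs` and then the
        # model; patch both so it consumes the kwargs and outputs computed above instead
        # of rebuilding the batch and running a second forward pass.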
        @contextmanager
        def _patch_concatenated_forward():
            _old_concatenated_inputs = self.concatenated_inputs
            _old_model_call = model.__class__.__call__
            self.concatenated_inputs = lambda *args, **kwargs: model_kwargs
            model.__class__.__call__ = lambda *args, **kwargs: outputs
            try:
                yield
            finally:
                self.concatenated_inputs = _old_concatenated_inputs
                model.__class__.__call__ = _old_model_call

        with _patch_concatenated_forward():
            return super().concatenated_forward(model, model_kwargs)

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        res = super().compute_loss(model, inputs, return_outputs=return_outputs)
        # Compat with transformers>=4.46: when the model accepts loss kwargs, the Trainer
        # no longer divides the loss by gradient_accumulation_steps, so do it here.
        if num_items_in_batch is not None and self.model_accepts_loss_kwargs:
            loss = res[0] if return_outputs else res
            loss = loss / self.args.gradient_accumulation_steps
            return (loss, res[1:]) if return_outputs else loss
        return res

    def _get_train_sampler(self, train_dataset=None):
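        # Newer transformers versions pass `train_dataset` to `_get_train_sampler`;
        # forward it only when the parent signature accepts it.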
        get_train_sampler = super()._get_train_sampler
        parameters = inspect.signature(get_train_sampler).parameters
        kwargs = {'train_dataset': train_dataset} if 'train_dataset' in parameters else {}
        return get_train_sampler(**kwargs)

    def get_per_token_logps(
        self,
        logits: torch.FloatTensor,
        labels: torch.LongTensor,
        label_pad_token_id=-100,
        reduction='mean',
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
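        """Return per-token log-probs, the logits reduced over the vocab dim, and the loss mask.

        Tokens equal to `label_pad_token_id` are excluded; `reduction` selects mean or sum
        over the vocabulary dimension of `logits`.
        """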
        if logits.shape[:-1] != labels.shape:
            raise ValueError(f'Logits (batch and sequence length dim) {logits.shape[:-1]} '
                             f'and labels must have the same shape {labels.shape}')
        loss_mask = labels != label_pad_token_id
        labels = labels.clone()
        labels[~loss_mask] = 0
        if reduction == 'mean':
            reduce_logits = logits.mean(-1)
        elif reduction == 'sum':
            reduce_logits = logits.sum(-1)
        else:
            raise ValueError(f'Invalid reduction: {reduction}')
        if self.template.sequence_parallel_size == 1:
            # https://github.com/huggingface/trl/pull/2799
            # Reduce peak vram consumption with efficient selective log_softmax
            per_token_logps = selective_log_softmax(logits, labels)
            per_token_logps[~loss_mask] = 0
            reduce_logits[~loss_mask] = 0
            return per_token_logps, reduce_logits, loss_mask
        else:
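            # Sequence-parallel path: compute log-probs on the local sequence shard, gather
            # them from all sequence-parallel ranks, then drop padded positions
            # (position_ids == -1) before applying the loss mask.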
            labels = labels.to(logits.device)
            loss_mask = loss_mask.to(logits.device)
            mean_logits = reduce_logits
            per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
            from swift.trainers.sequence_parallel.utils import GatherLoss
            from swift.trainers.sequence_parallel import sequence_parallel
            position_ids = sequence_parallel.real_position_ids
            total_per_token_logps, total_loss_mask = GatherLoss.apply(per_token_logps, loss_mask, 1, position_ids)
            total_mean_logits = sequence_parallel.gather(mean_logits, dim=1, position_ids=position_ids)
            if position_ids is not None and position_ids.min() == -1:
                _pos_mask = position_ids >= 0
                total_per_token_logps = total_per_token_logps[_pos_mask].contiguous()
                total_mean_logits = total_mean_logits[_pos_mask].contiguous()
                total_loss_mask = total_loss_mask[_pos_mask].contiguous()

            total_loss_mask = total_loss_mask.bool()
            total_per_token_logps = total_per_token_logps * total_loss_mask

            if total_per_token_logps.dim() == 1:
                total_per_token_logps = total_per_token_logps.unsqueeze(0)
                total_mean_logits = total_mean_logits.unsqueeze(0)
                total_loss_mask = total_loss_mask.unsqueeze(0)
            return total_per_token_logps, total_mean_logits, total_loss_mask
